package it.uniroma2.dtk.dt;

import it.uniroma2.dtk.common.Spectrum;
import it.uniroma2.util.math.ArrayMath;
import it.uniroma2.util.math.MatrixUtils;
import it.uniroma2.util.tree.Tree;
import it.uniroma2.util.vector.RandomVectorGenerator;
import it.uniroma2.util.vector.VectorComposer;
import it.uniroma2.util.vector.VectorProvider;

/**
 * Default implementation of the {@link DT} interface, according to the procedures provided in the paper.
 * This class abstracts from the actual ideal composition function implementation: subclasses provide it
 * by implementing {@link #op(double[], double[])}.
 *
 * @author Lorenzo Dell'Arciprete
 */
public abstract class DefaultAbstractDT implements DT {

	/** Source of the node label vectors; the constructors install a {@link RandomVectorGenerator}. */
	protected VectorProvider vectorProvider;
	
	protected int vectorSize;
	protected boolean usePos;
	protected boolean lexicalized;
	/**
	 * Square ROOT of the lambda decaying factor (despite the "Sq" suffix), i.e. sqrt(lambda).
	 * Defaults to 1, i.e. no decay.
	 */
	protected double lambdaSq = 1;
	
	/**
	 * Creates a DT whose label vectors come from a {@link RandomVectorGenerator}.
	 * @param randomOffset - random seed for vector generation and composition
	 * @param vectorsSize - the dimension of the desired vector space
	 * @param usePos - if true, POS information will be considered for the tree node labels
	 * @param lexicalized - if false, pre-terminal nodes are treated as terminal nodes, so leaf (lexical) nodes are ignored
	 * @param lambda - the value of the lambda decaying factor; must be non-negative
	 * @throws IllegalArgumentException if lambda is negative or NaN
	 * @throws Exception if the vector generator cannot be created
	 */
	public DefaultAbstractDT(int randomOffset, int vectorsSize, boolean usePos, boolean lexicalized, double lambda) throws Exception {
		// Validate before doing any other work, so an invalid lambda fails fast
		lambdaSq = sqrtOfLambda(lambda);
		vectorProvider = new RandomVectorGenerator(vectorsSize, randomOffset);
		this.vectorSize = vectorsSize;
		this.usePos = usePos;
		this.lexicalized = lexicalized;
	}
	
	/**
	 * Creates a DT with lambda = 1 (no decay).
	 * @see #DefaultAbstractDT(int, int, boolean, boolean, double)
	 */
	public DefaultAbstractDT(int randomOffset, int vectorsSize, boolean usePos, boolean lexicalized) throws Exception {
		this(randomOffset, vectorsSize, usePos, lexicalized, 1);
	}
	
	/**
	 * Creates a lexicalized DT that does not use POS information.
	 * @see #DefaultAbstractDT(int, int, boolean, boolean, double)
	 */
	public DefaultAbstractDT(int randomOffset, int vectorsSize, double lambda) throws Exception {
		this(randomOffset, vectorsSize, false, true, lambda);
	}
	
	/**
	 * The vector composition function
	 */
	public abstract double[] op(double[] v1, double[] v2) throws Exception;
	
	public int getVectorSize() {return vectorSize;}
	
	public VectorProvider getVectorProvider() {return vectorProvider;}
	
	/**
	 * Computes the distributed tree of x according to Equation 6, i.e. the sum of s(n) over the nodes of x.
	 * @param x - the input tree
	 * @return the distributed tree vector; if the computation fails, the error is printed and the partial
	 *         sum accumulated up to that point is returned
	 */
	public double[] dt(Tree x) {
		//The Spectrum object will collect the values of s(n) during the course of the computation
		Spectrum result = new Spectrum();
		result.setVector(MatrixUtils.uniformVector(vectorSize, 0));
		try {
			sRecursive(x, result);
		}
		catch(Exception e) {
			// NOTE(review): best effort, the error is only printed and the partial sum is returned
			e.printStackTrace();
		}
		return result.getVector();
	}
	
	/**
	 * Recursive computation of function s(n) for the root of the input tree.
	 * To save time and space, the computed s(n) is directly added to the final sum of Equation 6.
	 * @param node - the input tree or, equivalently, its root node
	 * @param sum - the object collecting the sum of s(n) for each node in the tree
	 * @return s(node)
	 * @throws Exception if a label vector cannot be obtained or the composition fails
	 */
	protected double[] sRecursive(Tree node, Spectrum sum) throws Exception {
		//If node is terminal, s(n) is the zero vector
		double[] result = MatrixUtils.uniformVector(vectorSize, 0);
		if (node.isTerminal())
			return result;
		boolean preterminal = true;
		//Recursive computation of s(n). The arbitrary composition order is n#(...((c1#c2)#c3)#...#cn)
		//(note that the composition order might be different, as long as it is consistent)
		for (int i=0; i<node.getChildren().size(); i++) {
			Tree child = node.getChildren().get(i);
			double[] childVector = getLabelVector(child);
			if (!child.isTerminal()) {
				//according to the definition of s(n), childVector is (vec(child) + sqrt(lambda)*s(child));
				//the label vector fetched above is reused instead of being requested a second time.
				//The recursive call also adds s(child) to sum.
				childVector = VectorComposer.sum(childVector, ArrayMath.scalardot(lambdaSq, sRecursive(child, sum)));
				preterminal = false;
			}
			result = (i == 0) ? childVector : op(result, childVector);
		}
		if (preterminal && !lexicalized)
			//If we are not in a lexicalized setting, a pre-terminal node is treated as a terminal node
			return MatrixUtils.uniformVector(vectorSize, 0);
		result = op(getLabelVector(node), result);
		//Adding s(node) to the final sum
		sum.setVector(VectorComposer.sum(sum.getVector(), result));
		return result;
	}
	
	/**
	 * Computes the distributed tree fragment (DTF) of x according to Definition 1, also considering
	 * parameter lambda as in the measure described for Theorem 5: the composition of every internal node
	 * other than the root is multiplied by sqrt(lambda) before being composed into its parent, while leaf
	 * nodes (which may be non-terminals, since x is a fragment) contribute their plain label vector.
	 * @param x - the tree fragment
	 * @return the DTF of x, or the zero vector if the computation fails (the error is printed)
	 */
	public double[] dtf(Tree x) {
		try {
			if (x.getChildren().size() == 0)
				// NOTE(review): kept for backward compatibility. A bare single-node fragment yields
				// vec(x)/sqrt(lambda), which is not finite for lambda == 0. Confirm whether callers need this.
				return ArrayMath.scalardot(1/lambdaSq, getLabelVector(x));
			return dtfOfInternalNode(x);
		}
		catch(Exception e) {
			e.printStackTrace();
			return MatrixUtils.uniformVector(vectorSize, 0);
		}
	}
	
	/**
	 * Composes a fragment node that has at least one child as x#(...((c1'#c2')#c3')#...#ck'),
	 * where ci' is the contribution of the i-th child (see {@link #dtfContribution(Tree)}).
	 * The composition order is slightly different from the one in Definition 1, but it is applied consistently.
	 */
	private double[] dtfOfInternalNode(Tree x) throws Exception {
		double[] composed = dtfContribution(x.getChildren().get(0));
		for (int i=1; i < x.getChildren().size(); i++)
			composed = op(composed, dtfContribution(x.getChildren().get(i)));
		// The node's own label is requested after its children, keeping a post-order sequence of label requests
		return op(getLabelVector(x), composed);
	}
	
	/**
	 * Returns what a fragment child contributes to its parent's composition: its plain label vector if it is
	 * a leaf, sqrt(lambda) times its own composition otherwise. Leaves are never scaled, so no division by
	 * sqrt(lambda) is needed and lambda == 0 stays well defined.
	 */
	private double[] dtfContribution(Tree child) throws Exception {
		if (child.getChildren().size() == 0)
			return getLabelVector(child);
		return ArrayMath.scalardot(lambdaSq, dtfOfInternalNode(child));
	}
	
	/**
	 * Returns the vector associated with the node label: the one of node.getUsePosLabel() when usePos
	 * is set, the one of node.getRootLabel() otherwise.
	 * @throws Exception if the vector provider fails
	 */
	public double[] getLabelVector(Tree node) throws Exception {
		if (usePos) return vectorProvider.getVector(node.getUsePosLabel());
		else return vectorProvider.getVector(node.getRootLabel());
	}
	
	/** @return the lambda decaying factor, recomputed as the square of the stored sqrt(lambda) */
	public double getLambda() {
		return lambdaSq*lambdaSq;
	}
	
	/**
	 * @param lambda - the new lambda decaying factor; must be non-negative
	 * @throws IllegalArgumentException if lambda is negative or NaN
	 */
	public void setLambda(double lambda) {
		lambdaSq = sqrtOfLambda(lambda);
	}
	
	/**
	 * Validates lambda and returns its square root.
	 * Negative or NaN values are rejected because Math.sqrt would turn them into NaN, silently
	 * corrupting every vector computed afterwards.
	 */
	private static double sqrtOfLambda(double lambda) {
		if (!(lambda >= 0))
			throw new IllegalArgumentException("lambda must be a non-negative number, got " + lambda);
		return Math.sqrt(lambda);
	}

	public boolean isUsePos() {
		return usePos;
	}

	public void setUsePos(boolean usePos) {
		this.usePos = usePos;
	}

	public boolean isLexicalized() {
		return lexicalized;
	}

	public void setLexicalized(boolean lexicalized) {
		this.lexicalized = lexicalized;
	}
}
